/* Copyright 2018-2020 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * @file
 * Secret key (symmetric) encryption.
 * Uses AES-GCM 256, which also includes authentication.
 *
 * Lower-level functions implemented using Mbed TLS.
 * See also skenc_common.cpp for Mbed TLS-independent code.
 */

#include <string.h>  // memcmp()
#include <mbedtls/gcm.h>

#include "crypto_shared.h"
#include "error.h"
#include "skenc.h"

#ifndef CRYPTOLIB_MBEDTLS
#error "CRYPTOLIB_MBEDTLS must be defined to compile source with Mbed TLS."
#endif

namespace pcrypto = tcf::crypto;
namespace Error = tcf::error; // Error handling
namespace constants = tcf::crypto::constants;


/**
 * Encrypt a message using AES-GCM authenticated encryption.
 *
 * Appends a 16 byte (128 bit) authentication tag (sometimes called a MAC)
 * to the output cipher text:
 *     output = ciphertext + authentication tag
 * The authentication tag is not encrypted. The IV is NOT included in the
 * output; the caller must keep it and pass it separately to
 * DecryptMessage(). No additional authenticated data (AAD) is used.
 *
 * Throws RuntimeError, ValueError.
 *
 * @param key     Secret AES-256 encryption key.
 *                Generated by GenerateKey()
 * @param iv      96-bit initialization Vector (IV). Generated by GenerateIV()
 * @param message Binary data to encrypt. Must not be empty.
 * @returns Byte array containing encrypted data and appended auth tag
 */
ByteArray pcrypto::skenc::EncryptMessage(
        const ByteArray& key, const ByteArray& iv, const ByteArray& message) {
    // Owns an Mbed TLS GCM context for the duration of this call and frees
    // it on every exit path, including exceptions.
    struct GcmContextGuard {
        mbedtls_gcm_context ctx;
        GcmContextGuard() { mbedtls_gcm_init(&ctx); }
        ~GcmContextGuard() { mbedtls_gcm_free(&ctx); }
        GcmContextGuard(const GcmContextGuard&) = delete;
        GcmContextGuard& operator=(const GcmContextGuard&) = delete;
    };

    // Sanity checks
    if (key.size() != constants::SYM_KEY_LEN) {
        std::string msg(
            "Crypto Error (EncryptMessage): Wrong AES-GCM key length");
        throw Error::ValueError(msg);
    }

    if (iv.size() != constants::IV_LEN) {
        std::string msg(
            "Crypto Error (EncryptMessage): Wrong AES-GCM IV length");
        throw Error::ValueError(msg);
    }

    if (message.empty()) {
        std::string msg(
            "Crypto Error (EncryptMessage): Cannot encrypt the empty message");
        throw Error::ValueError(msg);
    }

    const size_t pt_len = message.size();
    // Output layout: [ciphertext (pt_len bytes)][auth tag (TAG_LEN bytes)].
    // mbedtls_gcm_update() writes exactly as many bytes as it reads, so the
    // buffer is sized exactly; no extra block of slack is required.
    ByteArray ct(pt_len + constants::TAG_LEN);

    // Initialize encryption.
    // MbedTLS expects key length is in bits and IV length in bytes.
    GcmContextGuard aes_gcm;
    int rc = mbedtls_gcm_setkey(&aes_gcm.ctx, MBEDTLS_CIPHER_ID_AES,
        key.data(), key.size() * 8);
    if (rc != 0) {
        std::string msg(
            "Crypto Error (EncryptMessage): Mbed TLS could not set AES key");
        throw Error::RuntimeError(msg);
    }

    rc = mbedtls_gcm_starts(&aes_gcm.ctx, MBEDTLS_GCM_ENCRYPT,
        iv.data(), constants::IV_LEN, nullptr, 0);
    if (rc != 0) {
        std::string msg(
            "Crypto Error (EncryptMessage): Mbed TLS could not set AES GCM IV");
        throw Error::RuntimeError(msg);
    }

    // Encrypt message (with no IV prepended) into the front of the buffer
    rc = mbedtls_gcm_update(&aes_gcm.ctx, pt_len, message.data(), ct.data());
    if (rc != 0) {
        std::string msg(
            "Crypto Error (EncryptMessage): Mbed TLS could not update "
            "AES-GCM encryption");
        throw Error::RuntimeError(msg);
    }

    // Generate the auth tag directly into the tail of the output buffer
    rc = mbedtls_gcm_finish(&aes_gcm.ctx, ct.data() + pt_len,
        constants::TAG_LEN);
    if (rc != 0) {
        std::string msg("Crypto Error (EncryptMessage): "
            "Mbed TLS could not get AES-GCM TAG");
        throw Error::RuntimeError(msg);
    }

    // The GCM context is released by aes_gcm's destructor.
    return ct;
}  // pcrypto::skenc::EncryptMessage


/**
 * Decrypt a message using AES-GCM authenticated decryption.
 *
 * Expects a 12 byte (96 bit) IV (sometimes called a nonce), passed
 * separately from the message, and a 16 byte (128 bit) authentication tag
 * (sometimes called a MAC) appended to the input cipher text:
 *     message = ciphertext + authentication tag
 * The authentication tag is not encrypted. No additional authenticated
 * data (AAD) is used.
 *
 * Throws RuntimeError, ValueError,
 * CryptoError (message authentication failure).
 *
 * @param key         Secret AES-256 encryption key.
 *                    Generated by GenerateKey()
 * @param iv          96-bit initialization Vector (IV).
 *                    Generated by GenerateIV()
 * @param message     binary data to decrypt. Generated by EncryptMessage()
 *                    Includes appended authentication tag.
 *                    IV is separate (not prepended to message)
 * @param message_len Length of message in bytes, including the tag
 * @returns Byte array containing decrypted data
 */
ByteArray pcrypto::skenc::DecryptMessage(
        const ByteArray& key, const char iv[constants::IV_LEN],
        const char *message, size_t message_len) {
    // Owns an Mbed TLS GCM context for the duration of this call and frees
    // it on every exit path, including exceptions.
    struct GcmContextGuard {
        mbedtls_gcm_context ctx;
        GcmContextGuard() { mbedtls_gcm_init(&ctx); }
        ~GcmContextGuard() { mbedtls_gcm_free(&ctx); }
        GcmContextGuard(const GcmContextGuard&) = delete;
        GcmContextGuard& operator=(const GcmContextGuard&) = delete;
    };

    // Sanity checks
    if (key.size() != constants::SYM_KEY_LEN) {
        std::string msg(
            "Crypto Error (DecryptMessage): Wrong AES-GCM key length");
        throw Error::ValueError(msg);
    }

    if (message_len < constants::TAG_LEN) {
        std::string msg(
            "Crypto Error (DecryptMessage): AES-GCM message smaller "
            "than minimum length (TAG length)");
        throw Error::ValueError(msg);
    }

    if (message == nullptr) {
        std::string msg(
            "Crypto Error (DecryptMessage): AES-GCM message is null");
        throw Error::ValueError(msg);
    }

    // Computed only after the length check above, so this cannot underflow
    // and tag_expected always points inside the caller's buffer.
    // Both plaintext and cryptotext length (excluding auth tag length).
    const size_t len = message_len - constants::TAG_LEN;
    const unsigned char *ct = reinterpret_cast<const unsigned char *>(message);
    const unsigned char *tag_expected = ct + len;
    unsigned char tag_generated[constants::TAG_LEN];
    // Sized to message_len (not len) so data() is a valid non-null pointer
    // even when the ciphertext itself is empty; trimmed to len on success.
    ByteArray pt(message_len);

    // Initialize decryption.
    // MbedTLS expects key length is in bits and IV length in bytes.
    GcmContextGuard aes_gcm;
    int rc = mbedtls_gcm_setkey(&aes_gcm.ctx, MBEDTLS_CIPHER_ID_AES,
        key.data(), key.size() * 8);
    if (rc != 0) {
        std::string msg(
            "Crypto Error (DecryptMessage): Mbed TLS could not set AES key");
        throw Error::RuntimeError(msg);
    }

    rc = mbedtls_gcm_starts(&aes_gcm.ctx, MBEDTLS_GCM_DECRYPT,
        reinterpret_cast<const unsigned char *>(iv), constants::IV_LEN,
        nullptr, 0);
    if (rc != 0) {
        std::string msg(
            "Crypto Error (DecryptMessage): Mbed TLS could not set AES GCM IV");
        throw Error::RuntimeError(msg);
    }

    // Decrypt message (IV is separate; omit appended auth tag)
    rc = mbedtls_gcm_update(&aes_gcm.ctx, len, ct, pt.data());
    if (rc != 0) {
        std::string msg(
            "Crypto Error (DecryptMessage): Mbed TLS could not update "
            "AES-GCM decryption");
        throw Error::RuntimeError(msg);
    }

    // Generate the auth tag from decrypting the message
    rc = mbedtls_gcm_finish(&aes_gcm.ctx, tag_generated, constants::TAG_LEN);
    if (rc != 0) {
        std::string msg("Crypto Error (DecryptMessage): "
            "Mbed TLS could not get AES-GCM TAG");
        throw Error::RuntimeError(msg);
    }

    // Compare expected tag from the input cipher text with the generated tag
    // in constant time: OR together the XOR of every byte pair instead of
    // stopping at the first mismatch (as memcmp may), so the comparison time
    // does not reveal how many leading tag bytes were correct.
    unsigned char diff = 0;
    for (size_t i = 0; i < constants::TAG_LEN; ++i) {
        diff = static_cast<unsigned char>(
            diff | (tag_expected[i] ^ tag_generated[i]));
    }
    if (diff != 0) {
        // Wipe the unauthenticated plaintext before the buffer is released.
        // The volatile access keeps the compiler from eliding the stores.
        volatile unsigned char *wipe = pt.data();
        for (size_t i = 0; i < pt.size(); ++i) {
            wipe[i] = 0;
        }
        std::string msg(
            "Crypto Error (DecryptMessage): AES_GCM authentication "
            "failed, plaintext is not trustworthy");
        throw Error::CryptoError(msg);
    }

    // Build and return decrypted output string
    pt.resize(len);

    // The GCM context is released by aes_gcm's destructor.
    return pt;
}  // pcrypto::skenc::DecryptMessage
